{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from typing import List, Optional, Union, Dict\n",
    "from sentencepiece import SentencePieceProcessor\n",
    "from transformers import PreTrainedTokenizer\n",
    "from transformers.utils import logging, PaddingStrategy\n",
    "from transformers.tokenization_utils_base import EncodedInput, BatchEncoding\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class SPTokenizer:\n",
    "    \"\"\"SentencePiece wrapper that adds five extra special tokens.\n",
    "\n",
    "    ``[MASK]``, ``[gMASK]``, ``[sMASK]``, ``sop`` and ``eop`` are not looked up in\n",
    "    the SentencePiece model; they get consecutive ids starting at the model's\n",
    "    vocabulary size, and ``n_words`` counts them as well.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model_path: str):\n",
    "        \"\"\"Load the SentencePiece model at ``model_path`` and register the special tokens.\n",
    "\n",
    "        Raises ``AssertionError`` when ``model_path`` is not an existing file.\n",
    "        \"\"\"\n",
    "        assert os.path.isfile(model_path), model_path\n",
    "        self.sp_model = SentencePieceProcessor(model_file=model_path)\n",
    "\n",
    "        # Ids reported by the underlying model; padding reuses the EOS id.\n",
    "        base_vocab_size = self.sp_model.vocab_size()\n",
    "        self.bos_id: int = self.sp_model.bos_id()\n",
    "        self.eos_id: int = self.sp_model.eos_id()\n",
    "        self.pad_id: int = self.sp_model.eos_id()\n",
    "        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()\n",
    "\n",
    "        # Special tokens occupy the ids directly after the model vocabulary.\n",
    "        extra_tokens = [\"[MASK]\", \"[gMASK]\", \"[sMASK]\", \"sop\", \"eop\"]\n",
    "        self.special_tokens = {\n",
    "            name: base_vocab_size + offset for offset, name in enumerate(extra_tokens)\n",
    "        }\n",
    "        self.index_special_tokens = {\n",
    "            token_id: name for name, token_id in self.special_tokens.items()\n",
    "        }\n",
    "        self.n_words: int = base_vocab_size + len(extra_tokens)\n",
    "\n",
    "    def tokenize(self, s: str):\n",
    "        \"\"\"Split ``s`` into SentencePiece pieces.\"\"\"\n",
    "        return self.sp_model.EncodeAsPieces(s)\n",
    "\n",
    "    def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:\n",
    "        \"\"\"Encode ``s`` to ids, optionally wrapped in the BOS and/or EOS id.\"\"\"\n",
    "        assert type(s) is str\n",
    "        ids = self.sp_model.encode(s)\n",
    "        prefix = [self.bos_id] if bos else []\n",
    "        suffix = [self.eos_id] if eos else []\n",
    "        return prefix + ids + suffix\n",
    "\n",
    "    def decode(self, t: List[int]) -> str:\n",
    "        \"\"\"Turn a list of ids back into text.\"\"\"\n",
    "        return self.sp_model.decode(t)\n",
    "\n",
    "    def decode_tokens(self, tokens: List[str]) -> str:\n",
    "        \"\"\"Join a list of pieces back into text.\"\"\"\n",
    "        return self.sp_model.DecodePieces(tokens)\n",
    "\n",
    "    def convert_token_to_id(self, token):\n",
    "        \"\"\"Return the id of ``token``; the extra special tokens are checked first.\"\"\"\n",
    "        if token not in self.special_tokens:\n",
    "            return self.sp_model.PieceToId(token)\n",
    "        return self.special_tokens[token]\n",
    "\n",
    "    def convert_id_to_token(self, index):\n",
    "        \"\"\"Return the piece for ``index``; ids of the extra special tokens map to \"\".\"\"\"\n",
    "        is_special = index in self.index_special_tokens\n",
    "        return \"\" if is_special else self.sp_model.IdToPiece(index)\n",
    "\n",
    "# Build the wrapper from the SentencePiece model file at this relative path;\n",
    "# the constructor asserts that the file exists.\n",
    "sptokenizer = SPTokenizer(model_path=\"raw_model/chinese_vocab/tokenizer.model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Collect the string form of every id the wrapper knows about, in id order.\n",
    "# `n_words` is the SentencePiece vocabulary size plus the extra special tokens,\n",
    "# so the bound is exact: no oversized magic range that could truncate a larger\n",
    "# vocabulary, and no exception-driven break that could hide an unrelated error.\n",
    "# Ids of the extra special tokens yield \"\".\n",
    "all_list = [\n",
    "    sptokenizer.convert_id_to_token(token_id)\n",
    "    for token_id in tqdm(range(sptokenizer.n_words))\n",
    "]\n",
    "\n",
    "len(all_list)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_list[40000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview only the head of the vocabulary; printing every entry floods the output.\n",
    "print(all_list[:100])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_list.index('你')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Inclusive code-point ranges that count as Chinese characters.\n",
    "_CJK_RANGES = (\n",
    "    (0x4E00, 0x9FFF),    # CJK Unified Ideographs\n",
    "    (0x3400, 0x4DBF),    # CJK Unified Ideographs Extension A\n",
    "    (0x20000, 0x2A6DF),  # Extension B\n",
    "    (0x2A700, 0x2B73F),  # Extension C\n",
    "    (0x2B740, 0x2B81F),  # Extension D\n",
    "    (0x2B820, 0x2CEAF),  # Extension E\n",
    "    (0xF900, 0xFAFF),    # CJK Compatibility Ideographs\n",
    "    (0x2F800, 0x2FA1F),  # CJK Compatibility Ideographs Supplement\n",
    ")\n",
    "\n",
    "\n",
    "def _is_chinese_char(cp):\n",
    "    \"\"\"Return True when the integer code point ``cp`` falls in one of ``_CJK_RANGES``.\"\"\"\n",
    "    return any(low <= cp <= high for low, high in _CJK_RANGES)\n",
    "\n",
    "\n",
    "def is_chinese(word: str):\n",
    "    \"\"\"Return True when ``word`` is non-empty and every character is Chinese.\n",
    "\n",
    "    '身高' and '神' qualify; '180', mixed strings and the empty string do not.\n",
    "    \"\"\"\n",
    "    # An empty string has no characters to check, so ``all`` alone would accept it;\n",
    "    # reject it explicitly so blank entries never count as Chinese words.\n",
    "    if not word:\n",
    "        return False\n",
    "    return all(_is_chinese_char(ord(char)) for char in word)\n",
    "\n",
    "# Deduplicate while keeping vocabulary (id) order. A plain set() orders strings by\n",
    "# hash, which varies between interpreter runs, so anything derived from this list\n",
    "# (such as the ids assigned to newly added tokens) would not be reproducible.\n",
    "chinese_list = list(dict.fromkeys(filter(is_chinese, all_list)))\n",
    "chinese_list[:5], len(chinese_list)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd \n",
    "\n",
    "# One row per candidate token together with its character count; empty strings are dropped.\n",
    "data = (\n",
    "    pd.DataFrame({'word': chinese_list})\n",
    "    .assign(len=lambda frame: frame['word'].map(len))\n",
    "    .query('len > 0')\n",
    ")\n",
    "\n",
    "# Number of tokens for each character count.\n",
    "data.groupby(['len']).agg(freq=('word', 'count'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data.query('len <= 10').sort_values(by=['len'], ascending=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the tokens that are at most 10 characters long; this is the list that\n",
    "# is added to the tokenizer further down.\n",
    "chinese_list_finally = data.query('len <= 10')['word'].tolist()\n",
    "chinese_list_finally[:5], len(chinese_list_finally)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 处理model和tokenizer\n",
    "\n",
    "## tokenizer "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from transformers import AutoTokenizer\n",
    "from transformers import AutoModelForCausalLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the base checkpoint whose vocabulary is being extended\n",
    "# (looks like a local directory -- TODO confirm it is not meant as a hub id).\n",
    "old_model_name_or_path = \"raw_model/mpt-7b-chat\"\n",
    "\n",
    "old_tokenizer = AutoTokenizer.from_pretrained(old_model_name_or_path)\n",
    "# trust_remote_code=True allows transformers to run modeling code shipped with the checkpoint.\n",
    "old_model = AutoModelForCausalLM.from_pretrained(old_model_name_or_path, trust_remote_code=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "50279"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(old_tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "80889"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Register the filtered Chinese tokens as additional tokenizer entries,\n",
    "# then show the new tokenizer size.\n",
    "old_tokenizer.add_tokens(chinese_list_finally)\n",
    "\n",
    "len(old_tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "80896"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Round the tokenizer size up to the next multiple of 64. Ceiling division leaves a\n",
    "# size that is already a multiple of 64 unchanged instead of adding a further 64.\n",
    "(len(old_tokenizer) + 63) // 64 * 64"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Embedding(50432, 4096)"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "old_model.get_input_embeddings()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Embedding(80896, 4096)"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Size the embedding matrix from the tokenizer (rounded up to a multiple of 64)\n",
    "# rather than from a hard-coded number, so it stays correct if the token list changes.\n",
    "padded_vocab_size = (len(old_tokenizer) + 63) // 64 * 64\n",
    "old_model.resize_token_embeddings(padded_vocab_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Embedding(80896, 4096)"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "old_model.get_input_embeddings()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "tokenizer config file saved in mpt_chat_7b_chinese_no_train/tokenizer_config.json\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens file saved in mpt_chat_7b_chinese_no_train/special_tokens_map.json\n",
      "Configuration saved in mpt_chat_7b_chinese_no_train/config.json\n",
      "Configuration saved in mpt_chat_7b_chinese_no_train/generation_config.json\n",
      "The model is bigger than the maximum size per checkpoint (4GB) and is going to be split in 8 checkpoint shards. You can find where each parameters has been saved in the index located at mpt_chat_7b_chinese_no_train/pytorch_model.bin.index.json.\n"
     ]
    }
   ],
   "source": [
    "# Output directory that receives both the extended tokenizer and the resized model.\n",
    "new_model_name_or_path = \"mpt_chat_7b_chinese_no_train\"\n",
    "\n",
    "old_tokenizer.save_pretrained(new_model_name_or_path)\n",
    "# max_shard_size splits the weights into several checkpoint files of at most 4GB each.\n",
    "old_model.save_pretrained(new_model_name_or_path,max_shard_size=\"4GB\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hz_net",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
